import bluepysnap
import pandas as pd
import matplotlib as mpl
from scipy.special import expit
import scipy as sc
from scipy.interpolate import make_interp_spline, PPoly
from scipy.optimize import curve_fit
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt


def get_spindle_amp(
    sim,
    *,
    time_binsize=5,
    window_start=3500,
    window_length=1000,
    n_windows=16,
    peak_height=2,
    peak_width=3,
):
    """Summarise firing-rate peaks of Rt_RC cells over consecutive time windows.

    The spike times of the ``"thalamus_neurons"`` population restricted to
    ``mtype == "Rt_RC"`` are binned into a population firing rate (spikes per
    cell per second). In each of ``n_windows`` windows of length
    ``window_length`` starting at ``window_start``, peaks are detected with
    ``scipy.signal.find_peaks``. For each window the mean peak rate and the
    span between the first and last peak are recorded; windows without any
    peak contribute 0 to both.

    All keyword arguments default to the values the analysis was originally
    hard-coded with, so ``get_spindle_amp(sim)`` behaves as before.

    Args:
        sim: Simulation object exposing ``sim.spikes[...]`` (a
            ``bluepysnap.Simulation`` in this script).
        time_binsize: Histogram bin width, in the report's time units. The Hz
            conversion below (``0.001 * time_binsize``) assumes milliseconds.
        window_start: Start time of the first analysis window.
        window_length: Length of each analysis window; windows are contiguous.
        n_windows: Number of analysis windows.
        peak_height: Minimum peak rate passed to ``find_peaks(height=...)``.
        peak_width: Minimum peak width, in bins, passed to
            ``find_peaks(width=...)``.

    Returns:
        tuple: ``(mean, std)`` over windows of the mean peak rate, followed by
        ``(mean, std)`` over windows of the first-to-last peak span.

    Raises:
        ValueError: If the filtered spike report contains no spikes.
    """
    from scipy.signal import find_peaks

    report = sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report

    times = report.index
    if len(times) == 0:
        raise ValueError("spike report for Rt_RC cells in 'thalamus_neurons' is empty")

    time_start = np.min(times)
    time_stop = np.max(times)
    # Append time_stop so the last (possibly shorter) bin still ends at the
    # final spike. That bin is normalised by the full time_binsize below.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    # Number of distinct cells in the report, used to turn counts into a per-cell rate.
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)
    left_edges = bin_edges[:-1]

    fs = []
    lengths = []
    for step in range(n_windows):
        low = window_start + step * window_length
        high = low + window_length
        # Inclusive on both ends: adjacent windows share their boundary bin.
        mask = (left_edges >= low) & (left_edges <= high)
        f = freq[mask]
        t = left_edges[mask]
        peaks = find_peaks(f, height=peak_height, width=peak_width)[0]
        if peaks.size == 0:
            fs.append(0)
            lengths.append(0)
        else:
            fs.append(np.mean(f[peaks]))
            lengths.append(t[peaks[-1]] - t[peaks[0]])
    return np.mean(fs), np.std(fs), np.mean(lengths), np.std(lengths)


def fit_func(_x, slope_lin, shift, slope_exp, shift_exp, amp_exp):
    """Evaluate a straight line plus a logistic step at ``_x``.

    Computes ``slope_lin * _x + shift + expit((_x - shift_exp) / slope_exp) * amp_exp``.
    The step is centred at ``shift_exp``, its steepness is set by
    ``slope_exp``, and its height by ``amp_exp``. Works element-wise on
    numpy arrays.
    """
    line = slope_lin * _x + shift
    step = expit((_x - shift_exp) / slope_exp) * amp_exp
    return line + step


if __name__ == "__main__":
    # Scan: for each CT percentage, compute the mean spindle peak rate of each
    # model variant, plot it against ``its``, save the table to
    # spindle_data.csv, and save the figure to scan_it2.pdf.

    base_path = Path("../../simulations/08_08_Cav_variants/")
    # Rate threshold (Hz). Its first crossing along ``its`` is stored per CT
    # percentage in ``shifts`` and drawn as a dashed line on the figure.
    thresh = 40
    shifts = []
    fig_all, ax_all = plt.subplots(figsize=(6, 3))
    # percs = [40, 50, 60, 70, 75, 80, 85, 90, 95, 100]
    percs = [60, 70, 75, 80, 85, 90, 95, 100]

    cmap = plt.get_cmap("plasma")
    # Colour every curve through the same normalisation the colorbar uses.
    # ``percs`` is unevenly spaced, so evenly spaced colours would not match
    # the colorbar.
    norm = mpl.colors.Normalize(vmin=percs[0], vmax=percs[-1])
    colors = cmap(norm(percs))
    its = [0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001]
    # (model name, variant number), in the same order as ``its``: the k-th
    # variant is plotted at its[k] (presumably that variant's it2 value).
    variants = [
        ("Ecel", 1),
        ("intermedEcel", 4),
        ("Spp", 2),
        ("intermedSpp", 5),
        ("intermedRunaway", 6),
        ("Runaway", 3),
    ]
    data = pd.DataFrame(index=its, columns=percs, dtype=float)
    for c, perc in zip(colors, percs):
        print("computing", perc)
        fs_mean = []
        fs_std = []
        for model, var in variants:
            sim_path = (
                base_path
                / f"variant{var}_{model}/variant{var}_{model}_SUPER_LONG_{perc}percent_CT_up_down/simulation_config.json"
            )
            if not sim_path.exists():
                # Placeholder keeps fs_mean aligned with ``its``. Dropping the
                # entry would shift every later value to the wrong it2.
                fs_mean.append(np.nan)
                fs_std.append(np.nan)
                continue
            sim = bluepysnap.Simulation(sim_path)
            f_mean, f_std, _l_mean, _l_std = get_spindle_amp(sim)
            fs_mean.append(f_mean)
            fs_std.append(f_std)

        fs_mean = np.array(fs_mean, dtype=float)
        fs_std = np.array(fs_std, dtype=float)
        data.loc[its, perc] = fs_mean

        # First threshold crossing of the piecewise-linear interpolant. Only
        # simulated points are used, and extrapolation beyond the sampled
        # range is disabled. Record NaN when there is no crossing.
        valid = ~np.isnan(fs_mean)
        its_valid = np.asarray(its)[valid]
        crossing = np.nan
        if its_valid.size >= 2:
            spl = make_interp_spline(its_valid, fs_mean[valid] - thresh, k=1)
            roots = PPoly.from_spline(spl).roots(extrapolate=False)
            if roots.size:
                crossing = roots[0]
        shifts.append(crossing)
        # _its = (100.0 - perc) * (-1.90298708e-05) + np.array(its)
        # print(_its, perc)
        ax_all.plot(its, fs_mean, "-+", c=c)
        # Re-save after every percentage so partial results survive a long run.
        fig_all.savefig("scan_it2.pdf")
    data.to_csv("spindle_data.csv")

    ax_all.set_ylabel("Mean spindle freq (Hz)")
    ax_all.set_xlabel("it2")
    ax_all.set_xlim(0.0005, 0.001)
    ax_all.set_ylim(0, 70)
    fig_all.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=ax_all,
        orientation="vertical",
        shrink=0.6,
        label="CT percentage",
    )

    ax_all.axhline(thresh, c="k", ls="--")
    fig_all.tight_layout()
    fig_all.savefig("scan_it2.pdf")
